import torch

from src.utils.trajectories import Trajectories, TrajectoryBuffer


class ReplayBuffer():
    """
    Replay buffer for offline training.

    Thin wrapper around a ``TrajectoryBuffer`` that adds reward-threshold
    pruning on insertion and three strategies for drawing training batches.

    According to the original notes, new samples overwrite the oldest ones
    once the buffer is full; that behaviour lives in ``TrajectoryBuffer``
    (not shown in this file).

    Fields
    ------
    config: The configuration mapping passed to the constructor.
    env: The environment; only ``env.dim`` is read here.
    gfn: Model used to re-sample trajectories backwards from stored final
        states when ``reward_only`` is set.
    device: ``config["device"]``.
    batch_size: Default number of trajectories returned by the samplers.
    reward_only: ``config["replay_buffer"]["reward_only"]``; when true,
        sampled trajectories are rebuilt via ``gfn.backward_sample_trajectories``.
    buffer: The underlying ``TrajectoryBuffer`` storage.
    alpha: Fraction of a biased batch drawn from the high-reward partition.
    beta: Fraction of stored trajectories treated as high-reward by
        ``biased_sample``.
    log_reward_threshold: Threshold handed to ``Trajectories.prune`` before
        trajectories are stored.
    """

    def __init__(self, config, env, gfn):
        self.config = config
        self.env = env
        self.gfn = gfn
        self.device = config["device"]
        self.batch_size = self.config["batch_size"]
        self.reward_only = self.config["replay_buffer"]["reward_only"]
        self.buffer = TrajectoryBuffer(
            capacity=config["replay_buffer"]["capacity"],
            trajectory_length=config["gfn"]["trajectory_length"],
            dim=env.dim,
            batch_size=self.batch_size,
            device=self.config["device"],
            reward_only=self.reward_only,
        )
        self.beta = self.config["replay_buffer"]["beta"]
        self.alpha = self.config["replay_buffer"]["alpha"]
        self.log_reward_threshold = self.config["replay_buffer"]["log_reward_threshold"]

        # Fail early on a misspelt sampling method instead of at the first
        # call to sample().
        assert self.config["replay_buffer"]["sampling_method"] in [
            "uniform_random",
            "biased",
            "reward_proportional",
        ], "Sampling method not recognised."

    def __len__(self):
        """Return ``len()`` of the underlying trajectory buffer."""
        return len(self.buffer)

    def batch_push(self, trajectories):
        """
        Prune a batch of trajectories and append it to the buffer.

        ``trajectories.prune`` is called with ``log_reward_threshold`` first,
        so the object passed in is presumably modified in place (confirm
        against ``Trajectories.prune``).
        """
        trajectories.prune(self.log_reward_threshold)
        self.buffer.extend(trajectories)

    def sample(self, batch_size=None):
        """
        Sample trajectories using the configured sampling method.

        Parameters
        ----------
        batch_size: Number of trajectories to draw; ``None`` falls back to
            ``self.batch_size``.

        Returns
        -------
        The sampled trajectories, or ``None`` when the chosen sampler cannot
        produce a batch (e.g. the buffer is empty).

        Raises
        ------
        ValueError: If the configured sampling method is unknown.
        """
        sm = self.config["replay_buffer"]["sampling_method"]
        if sm == "uniform_random":
            trajs = self.uniform_random_sample(batch_size)
        elif sm == "biased":
            trajs = self.biased_sample(batch_size)
        elif sm == "reward_proportional":
            trajs = self.reward_proportional_sample(batch_size)
        else:
            raise ValueError("Sampling method not recognised.")

        return trajs

    def _resolve_batch_size(self, batch_size):
        """Return ``batch_size``, or the configured default when it is ``None``."""
        return self.batch_size if batch_size is None else batch_size

    def _get_trajectories(self, indices, batch_size=None):
        """
        Get a batch of trajectories from the buffer according to the indices.

        In ``reward_only`` mode the sliced entries are used only for their
        final states and log-rewards: states and actions are regenerated by
        sampling backwards with ``self.gfn``.
        """
        sliced_trajs = self.buffer.slice_trajs(indices)
        if not self.reward_only:
            return sliced_trajs

        states, actions = self.gfn.backward_sample_trajectories(
            sliced_trajs.get_final_states(), batch_size
        )
        return Trajectories(states, actions, sliced_trajs.log_rewards)

    def uniform_random_sample(self, batch_size):
        """
        Sample a batch of stored trajectories uniformly at random
        (with replacement).

        Returns ``None`` when the buffer is empty or ``batch_size`` is 0.
        """
        batch_size = self._resolve_batch_size(batch_size)
        stored = self.buffer.stored_trajectories

        # torch.randint raises when the upper bound equals the lower bound,
        # so the empty-buffer case has to be handled before drawing.
        if stored == 0:
            return None

        indices = torch.randint(0, stored, (batch_size,))
        if indices.numel() == 0:
            return None

        return self._get_trajectories(indices, batch_size)

    def biased_sample(self, batch_size):
        """
        Sample alpha fraction of trajectories from the beta fraction of
        trajectories with highest reward, and 1-alpha fraction of
        trajectories from the remaining trajectories.

        Returns ``None`` when the buffer is empty or when either partition
        is empty (for instance when all stored log-rewards are equal).
        """
        batch_size = self._resolve_batch_size(batch_size)
        stored = self.buffer.stored_trajectories
        if stored == 0:
            return None

        log_rewards = self.buffer.log_rewards[:stored]

        # kthvalue requires 1 <= k <= stored; the raw product truncates to 0
        # for small buffers or beta close to 1.
        k = min(max(int((1 - self.beta) * stored), 1), stored)
        threshold = torch.kthvalue(log_rewards, k).values

        # flatten() keeps a 1-D index tensor even when only one element
        # matches (squeeze() would collapse it to 0-dim).
        top_indices = torch.nonzero(log_rewards >= threshold).flatten()
        bottom_indices = torch.nonzero(log_rewards < threshold).flatten()

        if top_indices.numel() == 0 or bottom_indices.numel() == 0:
            return None

        num_top = int(self.alpha * batch_size)
        num_bottom = batch_size - num_top

        # Draw with replacement inside each partition.
        top_indices = top_indices[torch.randint(len(top_indices), (num_top,))]
        bottom_indices = bottom_indices[torch.randint(len(bottom_indices), (num_bottom,))]

        indices = torch.cat((top_indices, bottom_indices))

        return self._get_trajectories(indices, batch_size)

    def reward_proportional_sample(self, batch_size):
        """
        Sample a batch of trajectories with probability proportional to
        reward (with replacement).

        Only the first ``stored_trajectories`` entries of the buffer are
        considered. Returns ``None`` when the buffer is empty or
        ``batch_size`` is 0.
        """
        batch_size = self._resolve_batch_size(batch_size)
        stored = self.buffer.stored_trajectories

        # Both max() on an empty tensor and multinomial with zero samples
        # raise, so bail out first.
        if stored == 0 or batch_size == 0:
            return None

        # Restrict to filled slots so never-written entries cannot be drawn.
        log_rewards = self.buffer.log_rewards[:stored]

        # Subtract the max before exponentiating for numerical stability.
        probs = torch.exp(log_rewards - log_rewards.max())
        probs /= probs.sum()
        indices = torch.multinomial(probs, batch_size, replacement=True)

        if indices.numel() == 0:
            return None

        return self._get_trajectories(indices, batch_size)
